之前对Pytorch的索引方式一直有点疑惑,昨天在小伙伴的帮助下对其有了更加深刻的理解。下面对这些进行一下总结。另外,值得注意的是,Pytorch号称直接对接的Numpy,因此下面的索引方法理论上也可以适用于Numpy的索引方式。
Pytorch的tensor索引方式有三种:分别为按照long tensor、按照bool tensor和按照byte tensor。下面分别进行介绍。
首先,说明一下,Pytorch默认打印的tensor只有四位小数,可以使用torch.set_printoptions(precision=8)多打印几个小数。
long tensor
首先看下面的代码。
1 | import torch |
代码解释:当b为long tensor时候,a[b]的实质为a[b, :],也就是取出b中元素作为a的行索引,而默认取出所有列。这可以使用下面代码解释:
1 | a[0] |
在这个代码里面,a[0]与a[0, :]的输出是一致的,所以也就是a{0]中的0作为了a的行索引,而默认取出所有列。同理,a[[0,1,1],:]中的[0,1,1]分别作为了a的行索引。
bool tensor和byte tensor
对于bool tensor,首先看下面的代码。
1 | import torch |
代码解释:当b为bool tensor时候,b中每一个位置的bool值表示是否取a对应位置的值,当为True的时候表示取出该值,当为False的时候,表示不取该值。
对于byte tensor,可以看下面代码:
1 | c = b.byte() |
可以看出来,bool tensor和byte tensor作为索引列表时效果是一样的,只是不推荐使用byte tensor而已。

